<?php

namespace DataCube\DataCubeAggregation\AI_Toolkit\Classification;

use DataCube\DataCubeAggregation\AI_Toolkit\Interfaces\ProbabilityEstimator;
use DataCube\DataCubeAggregation\AI_Toolkit\Interfaces\TrainerInterface;
use DataCube\DataCubeAggregation\Exception\CustomException;
use DataCube\DataCubeAggregation\Exception\CustomInvalidArgumentException;
use Symfony\Component\OptionsResolver\OptionsResolver;
use Rubix\ML\Classifiers\AdaBoost as RubixAdaBoost;
use Rubix\ML\Learner;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Datasets\Unlabeled;
/**
 * AdaBoost classifier backed by Rubix ML's AdaBoost implementation.
 *
 * Options are resolved by configureOptions() and forwarded, in order, to the
 * Rubix AdaBoost constructor (base, rate, ratio, epochs, minChange, window).
 */
class AdaBoost extends BaseClassifier implements TrainerInterface, ProbabilityEstimator
{
    /**
     * @param array $options See configureOptions() for the supported keys and defaults.
     *
     * @throws CustomInvalidArgumentException When 'base' is a Learner whose type is not a classifier.
     */
    public function __construct(array $options = [])
    {
        $resolver = new OptionsResolver();
        $this->configureOptions($resolver);
        $this->options = $resolver->resolve($options);

        $base = $this->options['base'];
        // Only classifier learners may be boosted; reject anything else up front.
        if ($base !== null && !$base->type()->isClassifier()) {
            throw new CustomInvalidArgumentException('Base Estimator must be'
                . " a classifier, {$base->type()} given.");
        }

        $this->classifier = new RubixAdaBoost(
            $base,
            $this->options['rate'],
            $this->options['ratio'],
            $this->options['epochs'],
            $this->options['minChange'],
            $this->options['window'],
        );
    }

    /**
     * Declare the accepted options and their defaults:
     *
     *  ?Learner $base = null,
     *  float $rate = 1.0,
     *  float $ratio = 0.8,
     *  int $epochs = 100,
     *  float $minChange = 1e-4,
     *  int $window = 5,
     *  bool $probabilityEstimates = true  (gates predictProbability())
     *
     * @param OptionsResolver $resolver
     * @return void
     */
    public function configureOptions(OptionsResolver $resolver): void
    {
        $resolver->setDefaults([
            'base' => null,
            'rate' => 1.0,
            'ratio' => 0.8,
            'epochs' => 100,
            'minChange' => 1e-4,
            'window' => 5,
            'probabilityEstimates' => true,
        ]);
        // Validate here so a non-Learner base fails with a clear option error
        // instead of a fatal method call on a non-object in the constructor.
        $resolver->setAllowedTypes('base', ['null', Learner::class]);
        $resolver->setAllowedTypes('probabilityEstimates', 'bool');
    }

    /**
     * Train the underlying Rubix AdaBoost model.
     *
     * @param array $samples Feature rows.
     * @param array $targets Labels aligned with $samples.
     *
     * @throws CustomException When the dataset is rejected or training fails.
     */
    public function train(array $samples, array $targets): void
    {
        try {
            $this->classifier->train(new Labeled($samples, $targets));
        } catch (\Exception $e) {
            // Surface the failure now instead of leaving an untrained model behind.
            throw new CustomException('Could not train: ' . $e->getMessage());
        }
    }

    /**
     * Predict the label of a single sample.
     *
     * @param array $samples A single feature row; it is wrapped into a one-row dataset.
     * @return array The Rubix predictions for that one-row dataset.
     *
     * @throws CustomException When the underlying model cannot predict.
     */
    public function predict(array $samples)
    {
        try {
            return $this->classifier->predict(new Unlabeled([$samples]));
        } catch (\Exception $e) {
            throw new CustomException('Could not predict: ' . $e->getMessage());
        }
    }

    /**
     * Estimate class probabilities for a single sample.
     *
     * Mirrors predict(): the sample is wrapped into a one-row dataset and the
     * Rubix result for that dataset is returned as-is.
     *
     * @param array $sample A single feature row.
     * @return array The Rubix probability estimates for the one-row dataset.
     *
     * @throws \InvalidArgumentException When the 'probabilityEstimates' option is false.
     * @throws CustomException When the underlying model cannot estimate probabilities.
     */
    public function predictProbability(array $sample)
    {
        if (!$this->options['probabilityEstimates']) {
            throw new \InvalidArgumentException('To predict probabilities you must build a classifier with $probabilityEstimates set to true.');
        }

        try {
            // Rubix exposes probabilities through proba(Dataset), not predictProbability().
            return $this->classifier->proba(new Unlabeled([$sample]));
        } catch (\Exception $e) {
            throw new CustomException('Could not predict probabilities: ' . $e->getMessage());
        }
    }

    /**
     * Return the per-epoch progress recorded by the underlying Rubix AdaBoost.
     *
     * NOTE(review): the exact return type depends on the installed Rubix ML version — confirm.
     */
    public function steps()
    {
        return $this->classifier->steps();
    }
}